
@sanchitintel commented Sep 18, 2025

Background

When multiple GEMMs are to be computed, each with its own canonical A, B, C, D matrices, GroupGEMM is useful for ensuring high GPU utilization and for avoiding the launch overhead that multiple separate GEMM kernel launches would otherwise incur. In cutlass, the vanilla GroupGEMM uses a persistent-kernel approach: the number of workgroups launched equals the number of Xe cores, and each workgroup loops, picking up work until none remains (here, one unit of work is the mainloop that computes one output tile of one of the GEMMs submitted through the GroupGEMM API). A minimal sketch of this scheduling loop follows.
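The sketch below only illustrates the persistent-kernel idea; `TileWork`, `decode_tile`, and `compute_tile` are hypothetical names, not the actual cutlass scheduler.

```cpp
// One unit of work: one output tile of one GEMM in the group.
struct TileWork {
  int group;   // which GEMM in the group
  int m_tile;  // output-tile row index
  int n_tile;  // output-tile column index
};

// Hypothetical helpers: map a flat tile index to (group, m, n), and run the
// mainloop + epilogue for that tile.
TileWork decode_tile(int flat_tile_idx);
void compute_tile(const TileWork& work);

// One workgroup per Xe core; each strides through the flattened list of
// output tiles of all GEMMs until no work remains.
void persistent_scheduler(int workgroup_id, int num_workgroups, int total_tiles) {
  for (int t = workgroup_id; t < total_tiles; t += num_workgroups) {
    compute_tile(decode_tile(t));
  }
}
```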

For Mixture of Experts (MoE) used in deep-learning models such as LLMs, the MoE GEMM use-case looks like this: each expert (corresponding to a group) has an associated weight of size N * K, which is essentially a column-major B matrix. All the B matrices are contiguous with respect to each other, i.e., their total size is num_groups * N * K. N and K are compile-time constants, while M varies per group. All A matrices are likewise contiguous with respect to each other: the set of tokens routed to an expert makes up the A matrix for that group. The pointer arithmetic this layout implies is sketched below.
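Illustrative helpers for that layout (hypothetical names; row-major A with K contiguous elements per row is my assumption):

```cpp
#include <cstdint>

// Column-major B for expert g starts N*K elements after expert (g-1)'s B.
template <typename T>
const T* expert_B(const T* B_base, int g, int64_t N, int64_t K) {
  return B_base + int64_t(g) * N * K;
}

// A for expert g starts after the rows of all preceding experts; M_prefix[g]
// is the cumulative number of tokens routed to experts 0..g-1.
template <typename T>
const T* expert_A(const T* A_base, int g, const int64_t* M_prefix, int64_t K) {
  return A_base + M_prefix[g] * K;
}
```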

MoEGEMM seems to be a natural candidate for leveraging GroupGEMM.

The problem

The cutlass GroupGEMM API is generic in that it requires pointers to the A, B, C, D tensors of each group.
To launch the kernel, the CPU needs to provide an array of these GPU pointers (the array itself also resides on the GPU), roughly as sketched below.
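A simplified picture of those arguments (an illustrative struct, not the exact cutlass argument type):

```cpp
struct GroupedGemmArgsSketch {
  int          num_groups;
  const void** ptr_A;         // device array: ptr_A[g] -> A of group g
  const void** ptr_B;         // device array: ptr_B[g] -> B of group g
  const void** ptr_C;         // device array: ptr_C[g] -> C of group g
  void**       ptr_D;         // device array: ptr_D[g] -> D of group g
  const int*   problem_sizes; // device array of per-group (M, N, K)
};
```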

However, for practical use-cases such as Mixture of Experts (each GroupGEMM group corresponds to one MoE expert), such lists can't be conveniently pre-computed in advance (it is indeed possible to build them at the beginning of the kernel and then synchronize across all workgroups, but that code can't be part of a generic GroupGEMM).

Solution proposed in this PR

Provide only the base A, B, C, D pointers, and also pass N and K, so that the canonical per-group A, B, C, D pointers can be computed on the fly, as in the device-side sketch below. (A prefix-sum algorithm to compute a cumulative sum of M might help, but based on our experimentation it doesn't seem to make much difference, as the small-M case is memory-bound anyway.)
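A hypothetical device-side sketch of that scheme (illustrative names and layouts: row-major A and D, column-major B; the C pointer would be derived like D's):

```cpp
#include <cstdint>

template <typename T>
struct MoePointers {
  const T* A;
  const T* B;
  T*       D;
};

template <typename T, int64_t N, int64_t K>  // N, K as compile-time constants
MoePointers<T> moe_pointers(const T* A_base, const T* B_base, T* D_base,
                            const int32_t* M_per_group, int group) {
  // Linear scan instead of a precomputed prefix sum; per the note above, a
  // prefix sum made little difference since small-M cases are memory-bound.
  int64_t m_offset = 0;
  for (int g = 0; g < group; ++g) m_offset += M_per_group[g];
  return { A_base + m_offset * K,            // row-major A, K elems per row
           B_base + int64_t(group) * N * K,  // column-major B, N*K per expert
           D_base + m_offset * N };          // row-major D, N elems per row
}
```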

To keep changes to the existing code minimal, the PR passes lists of size one instead of lists sized to the number of groups, as happens in the default case.
It adds a new kernel and a tile scheduler for MoEGEMM, while reusing the existing MMA and epilogue collectives (but with modified code for computing the A, B, C, D pointers).

We could instead add a template parameter to make these changes in the existing kernels and use if constexpr to separate the MoE path from the default GroupGEMM, roughly as sketched below. While the current implementation in this PR introduces duplication, the alternative would make the code messier.
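A sketch of that rejected alternative (hypothetical names):

```cpp
#include <cstdint>

template <bool IsMoE, typename T>
const T* group_A_pointer(const T* const* ptr_A_array,              // default path
                         const T* A_base, const int64_t* M_prefix, // MoE path
                         int64_t K, int group) {
  if constexpr (IsMoE) {
    return A_base + M_prefix[group] * K;  // derive the pointer on the fly
  } else {
    return ptr_A_array[group];            // read from the per-group device array
  }
}
```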

Performance

With a small M dimension for each GEMM problem, performance is worse than with a large M dimension due to the lower arithmetic intensity of the former case, but it is still better than launching a separate kernel for each GEMM problem.
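As a rough back-of-the-envelope estimate (not a measurement from this PR): one group's GEMM performs about $2MNK$ FLOPs while moving on the order of $MK + KN + MN$ elements, so with $b$ bytes per element,

$$
\text{arithmetic intensity} \approx \frac{2MNK}{(MK + KN + MN)\, b} \approx \frac{2M}{b} \quad \text{when } M \ll N, K,
$$

i.e., intensity grows roughly linearly with $M$, which is why the small-$M$ case stays memory-bound.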

Caveat

The example portrays just one way to use the API.
Also, it has mostly been copy-pasted from an existing example, so it can be revised further.

@sanchitintel changed the title from "MoEGEMM based on cutlass GroupGEMM" to "MoEGEMM as an extension of GroupGEMM" on Sep 19, 2025
(A comment from @sanchitintel was marked as off-topic.)

@sanchitintel marked this pull request as draft on September 29, 2025 17:06
@airMeng commented Oct 15, 2025

@sanchitintel, how is the PR progressing?

@sanchitintel force-pushed the refactored_fused_moe_backup branch from ef117e2 to 1545982 on October 15, 2025 05:18
@sanchitintel (Author) commented:

Hi @airMeng,

This PR doesn't yet include the updated code with performance optimizations, which I have locally. It does have the updated API interface in the example, though.

However, if @Antonyvance and the cutlass team would rather not have MoE GEMM as a separate kernel, then it would be better to port it to the applications directory instead.

Can you please explain why you asked?

Thanks
